import numpy as np
import torch

def vectorized_correlation(x,y):
	"""Return the Pearson correlation between matching columns of x and y.

	The correlation is computed along axis 0 (observations), independently
	for every remaining position, and the result is flattened.

	Args:
		x: array whose first axis indexes observations.
		y: array with the same layout as ``x`` (assumed to have the same
			number of observations; this is not checked).

	Returns:
		1-D array with one correlation coefficient per column. A constant
		column yields 0 because of the small epsilon added to the standard
		deviations.
	"""
	dim = 0
	n_obs = x.shape[dim]

	centered_x = x - x.mean(axis=dim, keepdims=True)
	centered_y = y - y.mean(axis=dim, keepdims=True)

	covariance = (centered_x * centered_y).sum(axis=dim, keepdims=True)

	bessel_corrected_covariance = covariance / (n_obs - 1)

	# The standard deviations must use the same (n - 1) divisor as the
	# covariance. Mixing a Bessel-corrected covariance with a population
	# standard deviation inflates the result by n / (n - 1) and lets
	# correlations exceed 1.
	x_var = (centered_x ** 2).sum(axis=dim, keepdims=True) / (n_obs - 1)
	y_var = (centered_y ** 2).sum(axis=dim, keepdims=True) / (n_obs - 1)

	# Epsilon guards against division by zero for constant columns.
	x_std = x_var ** 0.5 + 1e-8
	y_std = y_var ** 0.5 + 1e-8

	corr = bessel_corrected_covariance / (x_std * y_std)

	return corr.ravel()

class OLS_pytorch(object):
	"""Ordinary least squares with an intercept, solved with PyTorch.

	One regression is fitted per target (row of ``y``); all targets share
	the same design matrix and are solved together through the normal
	equations ``(X^T X) beta = X^T y``.

	Attributes set by ``fit``:
		coefficients: float tensor of shape (n_targets, n_features + 1, 1);
			index 0 along the second axis is the intercept.
		X: design matrix tensor including the leading column of ones.
		y: target tensor of shape (n_targets, n_samples).
	"""

	def __init__(self,use_gpu=False):
		"""Create an unfitted model.

		Args:
			use_gpu: when True, tensors are moved to CUDA with ``.cuda()``.
		"""
		self.coefficients = []
		self.use_gpu = use_gpu
		self.X = None
		self.y = None

	def fit(self,X,y):
		"""Fit one linear model per target.

		Args:
			X: array of shape (n_samples, n_features); a 1-D array is
				treated as n_samples observations of a single feature.
			y: array of shape (n_targets, n_samples); a 1-D array is
				treated as a single target.

		Returns:
			Coefficient tensor of shape (n_targets, n_features + 1, 1),
			also stored in ``self.coefficients``.

		Raises:
			ValueError: if ``y`` is not 2-D after reshaping or its sample
				count does not match ``X``.
		"""
		X = np.asarray(X)
		y = np.asarray(y)
		if X.ndim == 1:
			X = self._reshape_x(X)
		if y.ndim == 1:
			# Targets are stored row-wise, so a single 1-D target becomes
			# one row of n_samples values (not a column).
			y = y.reshape(1, -1)
		if y.ndim != 2 or y.shape[1] != X.shape[0]:
			raise ValueError(
				"y must have shape (n_targets, n_samples) with n_samples == "
				"X.shape[0]; got X %s and y %s" % (X.shape, y.shape)
			)

		X = self._concatenate_ones(X)

		X = torch.from_numpy(X).float()
		y = torch.from_numpy(y).float()
		if self.use_gpu:
			X = X.cuda()
			y = y.cuda()

		XtX = torch.matmul(X.t(), X)
		# Every target shares the same XtX, so all targets are solved as one
		# system with n_targets right-hand-side columns instead of copying
		# XtX once per target.
		Xty = torch.matmul(X.t(), y.t())
		betas = self._solve(XtX, Xty)
		# (n_features + 1, n_targets) -> (n_targets, n_features + 1, 1)
		betas_cholesky = betas.t().unsqueeze(2).contiguous()

		# Kept so that score() can evaluate the fit on the training data.
		self.X = X
		self.y = y
		self.coefficients = betas_cholesky
		return betas_cholesky

	def predict(self, entry):
		"""Predict every target for new samples.

		Args:
			entry: array of shape (n_samples, n_features); a 1-D array is
				treated as n_samples observations of a single feature.

		Returns:
			NumPy array of predictions with singleton axes squeezed out and
			then transposed: (n_samples, n_targets) in the general case,
			1-D when there is a single target or a single sample.

		Raises:
			RuntimeError: if the model has not been fitted.
		"""
		if not torch.is_tensor(self.coefficients):
			raise RuntimeError("fit() must be called before predict()")
		entry = np.asarray(entry)
		if entry.ndim == 1:
			entry = self._reshape_x(entry)
		entry = self._concatenate_ones(entry)
		entry = torch.from_numpy(entry).float()
		if self.use_gpu:
			entry = entry.cuda()
		# (n_samples, p) @ (n_targets, p, 1) -> (n_targets, n_samples, 1)
		prediction = torch.matmul(entry, self.coefficients)
		prediction = prediction.cpu().numpy()
		prediction = np.squeeze(prediction).T
		return prediction

	def score(self):
		"""Return the training-set R^2 of every target.

		R^2 is computed as the regression sum of squares divided by the
		total sum of squares, using the data passed to the last ``fit``.

		Returns:
			1-D NumPy array with one value per target.

		Raises:
			RuntimeError: if the model has not been fitted.
		"""
		if self.X is None or self.y is None or not torch.is_tensor(self.coefficients):
			raise RuntimeError("fit() must be called before score()")
		# (n_targets, n_samples, 1)
		yhat = torch.matmul(self.X, self.coefficients)
		# Per-target mean, shaped (n_targets, 1, 1) to broadcast over samples.
		ybar = (torch.sum(self.y, dim=1, keepdim=True) / self.y.shape[1]).unsqueeze(2)
		ssreg = torch.sum((yhat - ybar) ** 2, dim=1, keepdim=True)
		sstot = torch.sum((self.y.unsqueeze(2) - ybar) ** 2, dim=1, keepdim=True)
		score = ssreg / sstot
		return score.cpu().numpy().ravel()

	@staticmethod
	def _solve(A, B):
		"""Solve ``A @ Z = B`` for Z, across PyTorch versions."""
		linalg = getattr(torch, "linalg", None)
		if linalg is not None and hasattr(linalg, "solve"):
			return linalg.solve(A, B)
		# Older PyTorch only has torch.solve(B, A) -> (solution, LU).
		solution, _ = torch.solve(B, A)
		return solution

	def _reshape_x(self,X):
		"""Turn a 1-D array into a single-column 2-D array."""
		return X.reshape(-1,1)

	def _concatenate_ones(self,X):
		"""Prepend a column of ones (the intercept term) to X."""
		ones = np.ones(shape=X.shape[0]).reshape(-1,1)
		return np.concatenate((ones,X),1)
